fix(data): add bounds check for out-of-range indices in Batch.get_data()#49
Conversation
Batch.get_data() adjusts negative indices via `idx = num_graphs + idx` but does not validate the result. When the adjusted index is still negative (e.g. get_data(-5) on a 3-graph batch), Python's negative indexing silently wraps around, returning node-level data from one graph mixed with system-level data from another — silent data corruption. Add a bounds check after the adjustment to raise IndexError for any index outside [0, num_graphs).
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
Greptile SummaryThis PR adds a bounds check to Important Files Changed
Reviews (3): Last reviewed commit: "Merge branch 'main' into fix/get-data-bo..." | Re-trigger Greptile |
|
I posted a fix in #50: if that gets merged in before this PR then we'll see if the workflow runs as intended. Otherwise the tests pass like I said, so I would be fine with merging after you address my comment |
Validate bounds before adjusting negative indices so the error message shows the caller's original value instead of the intermediate adjusted index.
Description
Batch.get_data()adjusts negative indices viaidx = num_graphs + idxbut does not validate the result. When the adjusted index is still negative (e.g.get_data(-5)on a 3-graph batch), Python's negative indexing silently wraps around on the internal storage tensors. Node-level data (positions, forces) may come from one graph while system-level data (energies, cell) comes from a different graph, returning inconsistent data.Reproduction:
Type of Change
Changes Made
Batch.get_data()to raiseIndexErrorfor any index outside[0, num_graphs)test_get_data_out_of_range_raisesTesting
make pytest)make lint)Checklist